# Name of this task.
name: tpuvm_mnist

# Hardware requested for the task.
resources:
  # NOTE(review): presumably a TPU v2 device with 8 cores - confirm against
  # the accelerator names accepted by the tool that consumes this file.
  accelerators: tpu-v2-8

# The setup command.  Will be run under the working directory.
# Written to be re-runnable: the clone is skipped when the checkout already
# exists, and the package install is skipped when the conda env can already
# be activated.
setup: |
  # `git clone` into an existing non-empty directory fails, so only clone
  # when the checkout is missing.
  if [ ! -d flax ]; then
    git clone https://github.com/google/flax.git --branch v0.6.11
  fi

  # Test the activation directly instead of inspecting $? afterwards.
  if conda activate flax; then
    echo 'conda env exists'
  else
    conda create -n flax python=3.10 -y
    conda activate flax
    # Make sure to install TPU related packages in a conda env to avoid package conflicts.
    pip install "jax[tpu]==0.4.23" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
    pip install --upgrade clu
    pip install -e flax
    pip install tensorflow tensorflow-datasets
  fi


# The command to run.  Will be run under the working directory.
run: |
  # Enter the conda env prepared by the setup step.
  conda activate flax
  # Train from the MNIST example directory of the cloned flax checkout.
  cd flax/examples/mnist
  python3 main.py \
    --workdir=/tmp/mnist \
    --config=configs/default.py \
    --config.learning_rate=0.05 \
    --config.num_epochs=10
